import argparse
import json


def str2bool(v):
    if isinstance(v, bool):
       return v
    if v.lower() in ('yes', 'true', 't', 'y', '1','True'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0','False'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def parse():
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_name",type=str,default="halfcheetah")
    parser.add_argument("--env_type",type=str,default="mujoco")
    parser.add_argument("--exp_name",type=str,default="random_name")
    parser.add_argument("--image",type=str2bool,default=False)
    parser.add_argument("--deterministic_eval",type=str2bool,default=True)
    parser.add_argument("--device",type=str,default="cpu")
    parser.add_argument("--log_interval",type=int, default=4000)
    parser.add_argument("--num_updates",type=int,default=400000)
    parser.add_argument("--seed",type=int,default=0)
    parser.add_argument("--save",default="../gym_results/tests/")

    #Loading and saving currently do not work
    parser.add_argument("--load_params",type=str,default=None)
    parser.add_argument("--load_model_path",type=str,default=False)
    parser.add_argument("--save_model",type=str2bool,default=False)
    parser.add_argument("--video",type=int,default=0)
    parser.add_argument("--max_steps",type=int,default=300)

    parser.add_argument("--nlearn",type=int,default=1)
    parser.add_argument("--sub_nlearn",type=int,default=1)
    ###DRL Algorithm
    parser.add_argument("--gamma",type=float,default=0.98)
    parser.add_argument("--buffer_size",type=int,default=20000)
    parser.add_argument("--batch_size",type=int,default=512)
    parser.add_argument("--tau",type=float,default=5e-3)
    parser.add_argument("--clip_grad_sac",type=float,default=0)
    parser.add_argument("--alpha",type=float,default=2.0)
    parser.add_argument("--pi_epsilon",type=float,default=0)



    ##Neural networks
    parser.add_argument("--weight_decay",type=float,default=1e-6)
    parser.add_argument("--weight_decay2",type=float,default=1e-6)
    parser.add_argument("--lr_sac",type=float,default=5e-4)
    parser.add_argument("--neurons_sac",type=int,default=512)
    parser.add_argument("--layer_sac",type=int,default=3)

    parser.add_argument("--reward_coef",type=float,default=0.5)
    parser.add_argument("--double_critic",type=str2bool,default=True)
    parser.add_argument("--delay",type=float,default=1)
    parser.add_argument("--relabeling",type=int,default=3,help="Relabelling of B^G interactions")
    parser.add_argument("--relabeling2",type=int,default=0,help="Relabelling of B^S interactions")
    parser.add_argument("--embed_sac",type=int,default=0)

    ###Eventual warm up before learning
    parser.add_argument("--warmup",type=int,default=0)
    parser.add_argument("--stop_after_warmup",type=str2bool,default=False)

    ### OEGN parameters
    parser.add_argument("--target_prob",type=float,default=0.2)
    parser.add_argument("--num_som_updates",type=int,default=64)
    parser.add_argument("--a_threshold",type=float,default=0.6)
    parser.add_argument("--gwr_period",type=int,default=100)
    parser.add_argument("--gwr_lr",type=float,default=0.01)
    parser.add_argument("--gwr_lr2",type=float,default=0.0001)
    parser.add_argument("--gwr_tau",type=float,default=1)
    parser.add_argument("--delete_close",type=float,default=0.4)
    parser.add_argument("--ratio_for_som",type=str2bool,default=False)
    parser.add_argument("--error_max",type=int,default=600)
    parser.add_argument("--delete_ins",type=str2bool,default=False)
    parser.add_argument("--wait_to_delete",type=int,default=1)
    parser.add_argument("--a_max",type=int,default=600)
    parser.add_argument("--nodes_number",type=int,default=0)



    ###Local infoNCE
    parser.add_argument("--ratio_for_predictor", type=str2bool, default=True,help="Use the ratio to learn with LinfoNCE")
    parser.add_argument("--data_for_predictor", type=int, default=0)

    #Neural network
    parser.add_argument("--lr_pred",type=float,default=1e-4)
    parser.add_argument("--neurons_pred",type=int,default=256)
    parser.add_argument("--layer_pred",type=int,default=2)

    #positive sample
    parser.add_argument("--reg_coef",type=float,default=10.)
    parser.add_argument("--reg_pred",type=float,default=0.1)
    parser.add_argument("--square",type=str2bool,default=True)


    #Negative samples
    parser.add_argument("--tau_negative",type=float,default=1.)
    parser.add_argument("--coef_negative",type=float,default=0.5)
    parser.add_argument("--number_negative",type=int,default=10)

    #Consistency contraint
    parser.add_argument("--coef_clone_negative",type=float,default=1)
    parser.add_argument("--clone_negative",type=float,default=0.001)


    ###Coord policy
    parser.add_argument("--epsilon",type=float,default=0.01)
    parser.add_argument("--ratio_pi",type=float,default=0)
    parser.add_argument("--rew_coord_type",type=int,default=7)
    parser.add_argument("--lr_coord",type=float,default=0.05)
    parser.add_argument("--lr_coord2",type=float,default=0.05)
    parser.add_argument("--tau_coord",type=float,default=0.02)

    ###DisTop
    parser.add_argument("--skew_sample",type=float,default=0)
    parser.add_argument("--skew_select",type=float,default=0)
    parser.add_argument("--surface",type=str2bool,default=False)
    parser.add_argument("--num_latents",type=int, default=10)

    ###Planning
    parser.add_argument("--plan_interval", type=int, default=-1)
    parser.add_argument("--delta_reach",type=float, default=0.2)
    parser.add_argument("--plan_version",type=int, default=1)
    parser.add_argument("--plan_steps",type=int, default=1)
    parser.add_argument("--v2_duration",type=int, default=10)

    ###Diverse
    parser.add_argument("--state",type=str2bool, default=False)
    parser.add_argument("--type",type=int, default=0)

    args = parser.parse_args()


    if args.load_params is not None:
        with open(args.load_params,'r') as f:
            parsed_json=json.load(f)
            for namespace in parsed_json:
                for key in parsed_json[namespace]:
                    setattr(args, key, parsed_json[namespace][key])

    # import tools.consts as co
    # for const,val in vars(co).items():
    #     if const.startswith('_'):
    #         continue
    #     setattr(args,const,val)

    return args